import numpy as np
import torch

# Pick a count from the task name: 3 for any name starting with "mnli",
# 1 only for the exact name "stsb", and 2 for everything else.
# (Presumably the number of output labels per GLUE task — TODO confirm.)
# Because the check is an exact match, "stsb1" is NOT "stsb" and yields 2.
task = "stsb1"
if task.startswith("mnli"):
    a = 3
elif task == "stsb":
    a = 1
else:
    a = 2
print(a)

# Compare two integer arrays element-wise. The boolean mask marks matching
# positions, and its mean is the fraction of matches (1 of 3 here).
preds = np.array([1, 2, 3])
label = np.array([3, 2, 1])
hits = preds == label
print(hits)
print(hits.mean())

from sklearn import metrics

# sklearn's signature is accuracy_score(y_true, y_pred): ground truth comes first.
# Pass both by keyword so the roles cannot be silently swapped. Accuracy happens
# to be symmetric, but metrics such as precision/recall/f1 are not.
print(metrics.accuracy_score(y_true=label, y_pred=preds))

# Same element-wise comparison as above, but on torch tensors. The boolean
# mask is cast to float64 so .mean() is defined on it.
banner = "=" * 5
print(banner, "tensor", banner)
p1 = torch.tensor([1, 2, 3])
l2 = torch.tensor([3, 2, 1])
same = p1 == l2
print(same)
print(same.to(torch.float64).mean())
# accuracy_score expects (y_true, y_pred); pass by keyword to make the roles
# explicit (l2 holds the reference values, p1 the predictions). Convert the
# tensors to NumPy explicitly instead of relying on sklearn's implicit
# np.asarray coercion: that coercion fails for CUDA tensors and for tensors
# that require grad, whereas detach().cpu().numpy() handles both.
print(
    metrics.accuracy_score(
        y_true=l2.detach().cpu().numpy(),
        y_pred=p1.detach().cpu().numpy(),
    )
)
